#!/usr/bin/env python
import numpy as np, pandas as pd, argparse, yaml, sys

# Bin edges and labels applied to the `R_G_kpc` column.  Consecutive edges
# delimit one bin, so N edges yield N-1 bins and `RG_LABELS[i]` names the bin
# between `RG_BINS[i]` and `RG_BINS[i+1]`.  The labels use an en dash
# (U+2013), not an ASCII hyphen.
RG_BINS = [1.5, 3.0, 5.0, 8.0, 12.0]
RG_LABELS = ["1.5–3.0","3.0–5.0","5.0–8.0","8.0–12.0"]
# Bin edges and labels applied to the `Mstar_log10` column (same layout).
M_BINS = [10.2, 10.5, 10.8, 11.1]
M_LABELS = ["10.2–10.5","10.5–10.8","10.8–11.1"]

def label_bin(val, edges, labels):
    """Return the label of the bin that contains ``val``.

    Consecutive entries of ``edges`` delimit the bins; ``labels[i]`` names
    the bin between ``edges[i]`` and ``edges[i + 1]``.  Every bin is
    half-open, ``[lower, upper)``, except the final one, which also
    includes its upper edge.

    Returns ``np.nan`` when no bin matches.  A NaN ``val`` ends up here too,
    because every comparison against NaN is false.
    """
    final_pos = len(edges) - 2
    for pos, (lower, upper) in enumerate(zip(edges, edges[1:])):
        if not (val >= lower):
            continue
        # Only the last bin is closed on the right.
        if val < upper or (pos == final_pos and val <= upper):
            return labels[pos]
    return np.nan

def enforce_bins(df):
    """Make sure ``df`` carries complete ``R_G_bin`` and ``Mstar_bin`` columns.

    A bin column is (re)computed for every row when it is absent or when at
    least one of its entries is missing; in that case any labels already
    present in that column are overwritten.  ``R_G_bin`` is derived from
    ``R_G_kpc`` and ``Mstar_bin`` from ``Mstar_log10``.

    The frame is modified in place and the same object is returned.
    """
    def _incomplete(column):
        # True when the column does not exist or contains any missing label.
        return (column not in df.columns) or df[column].isna().any()

    def _radius_label(value):
        return label_bin(value, RG_BINS, RG_LABELS)

    def _mass_label(value):
        return label_bin(value, M_BINS, M_LABELS)

    if _incomplete("R_G_bin"):
        df["R_G_bin"] = df["R_G_kpc"].apply(_radius_label)
    if _incomplete("Mstar_bin"):
        df["Mstar_bin"] = df["Mstar_log10"].apply(_mass_label)
    return df

def compute_x(df, R_MW_kpc, log10_M_MW, eta):
    """Compute the scaled quantity ``x`` for every row of ``df``.

    The value returned is::

        x = R_G_kpc / (R_MW_kpc * (10 ** (Mstar_log10 - log10_M_MW)) ** eta)

    using the ``R_G_kpc`` and ``Mstar_log10`` columns of ``df``.  Rows whose
    scale factor is exactly zero yield NaN rather than a division by zero.

    Returns a Series aligned with ``df``'s index.
    """
    log_mass_offset = df["Mstar_log10"] - log10_M_MW
    mass_ratio = 10.0 ** log_mass_offset
    scale = mass_ratio ** eta
    # Mask a zero scale as missing so the division below produces NaN.
    safe_scale = scale.replace(0, np.nan)
    denominator = R_MW_kpc * safe_scale
    return df["R_G_kpc"] / denominator

def summarize_by_stack(df, x, x_threshold, around_width):
    """Summarise ``x`` within every (``Mstar_bin``, ``R_G_bin``) group of ``df``.

    Parameters
    ----------
    df : DataFrame with ``Mstar_bin`` and ``R_G_bin`` columns.  Rows whose
        bin label is missing are excluded from the grouping.
    x : Series looked up by ``df``'s index labels; NaN entries are ignored
        when computing the statistics.
    x_threshold : value ``x`` is compared against.
    around_width : relative half-width of the window around ``x_threshold``,
        i.e. ``[x_threshold * (1 - w), x_threshold * (1 + w)]``.

    Returns
    -------
    DataFrame with one row per group and the columns ``Mstar_bin``,
    ``R_G_bin``, ``n_lenses``, ``x_median``, ``x_p25``, ``x_p75``,
    ``frac_x_gt`` and ``frac_x_around``.  These columns are present even
    when there are no groups, so downstream output keeps a stable schema.

    Notes
    -----
    ``n_lenses`` counts every row of the group, including rows whose ``x``
    is NaN.  ``frac_x_gt`` uses an inclusive comparison (``x >= x_threshold``)
    and ``frac_x_around`` is inclusive at both ends of the window.  A group
    with no usable ``x`` values gets NaN for all statistics.
    """
    columns = [
        "Mstar_bin", "R_G_bin", "n_lenses",
        "x_median", "x_p25", "x_p75",
        "frac_x_gt", "frac_x_around",
    ]
    around_lo = x_threshold * (1 - around_width)
    around_hi = x_threshold * (1 + around_width)

    rows = []
    for (mb, rb), g in df.groupby(["Mstar_bin", "R_G_bin"], dropna=True):
        gx = x.loc[g.index].dropna()
        row = dict(Mstar_bin=mb, R_G_bin=rb, n_lenses=len(g))
        if gx.empty:
            row.update(
                x_median=np.nan, x_p25=np.nan, x_p75=np.nan,
                frac_x_gt=np.nan, frac_x_around=np.nan,
            )
        else:
            in_window = (gx >= around_lo) & (gx <= around_hi)
            row.update(
                x_median=float(gx.median()),
                x_p25=float(gx.quantile(0.25)),
                x_p75=float(gx.quantile(0.75)),
                frac_x_gt=float((gx >= x_threshold).mean()),
                frac_x_around=float(in_window.mean()),
            )
        rows.append(row)

    # Passing the column list keeps the schema when `rows` is empty; without
    # it an empty input would produce a frame with no columns at all.
    return pd.DataFrame(rows, columns=columns)

def main():
    """Command-line entry point: summarise ``x`` per bin over a parameter grid.

    Arguments (all required):
      --lenses  CSV file read with ``pd.read_csv``; must contain the columns
                ``R_G_kpc`` and ``Mstar_log10``.
      --config  YAML file with optional keys ``R_MW_kpc_grid`` (default
                ``[6.0]``), ``eta_grid`` (``[0.0]``), ``log10_M_MW``
                (``10.7``), ``x_threshold`` (``1.0``) and ``around_width``
                (``0.25``).
      --out     path of the CSV file to write.

    For every (R_MW, eta) pair of the grid the per-bin summary table is
    computed, tagged with the pair, and all tables are concatenated into the
    output file.  Exits with status 2 on an unusable config or a lenses file
    that lacks a required column.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--lenses", required=True)
    ap.add_argument("--config", required=True)
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    # Use a context manager so the config file handle is always closed.
    with open(args.config) as fh:
        cfg = yaml.safe_load(fh)
    # An empty YAML document loads as None; treat it as "all defaults".
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        print(f"[error] config must be a YAML mapping, got {type(cfg).__name__}",
              file=sys.stderr)
        sys.exit(2)

    grid_R = list(cfg.get("R_MW_kpc_grid",[6.0]))
    grid_eta = list(cfg.get("eta_grid",[0.0]))
    log10_M_MW = float(cfg.get("log10_M_MW", 10.7))
    x_threshold = float(cfg.get("x_threshold", 1.0))
    around_width = float(cfg.get("around_width", 0.25))

    # pd.concat raises an opaque ValueError on an empty list; fail clearly.
    if not grid_R or not grid_eta:
        print("[error] config grids 'R_MW_kpc_grid' and 'eta_grid' must be non-empty",
              file=sys.stderr)
        sys.exit(2)

    L = pd.read_csv(args.lenses)
    need = ["R_G_kpc","Mstar_log10"]
    for c in need:
        if c not in L.columns:
            print(f"[error] lenses file missing column '{c}'", file=sys.stderr)
            sys.exit(2)
    L = enforce_bins(L)

    all_tables = []
    for R_MW in grid_R:
        for eta in grid_eta:
            x = compute_x(L, R_MW_kpc=R_MW, log10_M_MW=log10_M_MW, eta=eta)
            T = summarize_by_stack(L, x, x_threshold, around_width)
            # Tag each table with the grid point it was computed for.
            T["R_MW_kpc"] = float(R_MW)
            T["eta"] = float(eta)
            all_tables.append(T)
    OUT = pd.concat(all_tables, ignore_index=True)
    OUT.to_csv(args.out, index=False)
    print(f"[info] wrote {args.out} with {len(OUT)} rows over grid {len(grid_R)}x{len(grid_eta)}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
